#!/usr/bin/python
# -*- encoding: utf-8 -*-

from optparse import OptionParser
import sys,os,re,unicodedata
from math import log
import openfst

# Command-line interface.  Positional arguments are the input text files;
# at least one is required.
parser = OptionParser("""
usage: %prog [options] ...

""")
# --utf8 switches symbol-table names from "U+xxxx" descriptions (ASCII
# characters as themselves) to the UTF-8 encoding of each character.
# The help text previously read "words only", copied from --words.
parser.add_option("-u","--utf8",help="use UTF-8 encoded characters as symbol names (default: U+xxxx for non-ASCII)",action="store_true")
# NOTE(review): --words and --wordfst are parsed but never read in the
# visible part of this script -- confirm whether they are still needed.
parser.add_option("-w","--words",help="words only",action="store_true")
parser.add_option("-W","--wordfst",help="put words on arcs (for debugging)",action="store_true")
parser.add_option("-v","--verbose",help="verbose",action="store_true")
parser.add_option("-n","--n",help="n-gram",type=int,default=3)
parser.add_option("-o","--output",help="output file",default="ngram.fst")
parser.add_option("-m","--minimize",help="perform minimization",action="store_true")
(options,args) = parser.parse_args()

# Without any input file there is nothing to do: show the usage and stop.
if len(args)<1:
    parser.print_usage()
    sys.exit(0)

# Order of the n-gram model (number of tokens per n-gram).
n = options.n

class Counter(dict):
    """Dictionary whose missing keys read as 0.

    Looking up an absent key stores 0 under that key and returns it, so
    ``counter[key] += 1`` works without prior initialisation.  Note that a
    plain read therefore also inserts the key.
    """

    def __getitem__(self,index):
        # Insert the zero default on first access, then do a normal lookup.
        if index not in self:
            dict.__setitem__(self,index,0)
        return dict.__getitem__(self,index)

words = Counter()
ngrams = Counter()
chars = Counter()

files = []
for file in args:
    s = open(file).read().decode("utf-8")
    print "# file:",file,"chars:",len(s)
    s.strip()
    files.append(s)
line = u" ".join(files)
line += u" "
line = re.sub(ur"\s*\n\*"," ",line)
line = re.sub(ur'\s+',' ',line)
line = re.sub(ur'[_#]','',line)

sep_re = re.compile(ur'([^A-Za-z0-9äöüÄÖÜß]+)')

ws = sep_re.split(line)

ngram = [u""]*n
                
for w in ws:
    for c in w: chars[c] += 1
    assert len(ngram)==n
    del ngram[0]
    ngram.append(w)
    t = tuple(ngram)
    if options.verbose: 
        print ("["+("|".join(t))+"]").encode("utf-8")
    words[t[-1]] += 1
    ngrams[t] += 1

print "words",len(words.keys())
print "ngrams",len(ngrams.keys())

print "# creating symbol table"

EPS = 0

symtab = openfst.SymbolTable("chars")
symtab.AddSymbol("EPS",EPS)
for c in chars.keys():
    c = unichr(ord(c))
    if c=='"':
        symtab.AddSymbol("''",ord(c))
    elif options.utf8:
        symtab.AddSymbol(c.encode("utf-8"),ord(c))
    else:
        desc = str(c) if ord(c)<128 else "U+%04x"%ord(c)
        symtab.AddSymbol(desc,ord(c))

fst = openfst.StdVectorFst()

nstates = 0
states = {}

initial = fst.AddState()
fst.SetStart(initial)
final = fst.AddState()
fst.SetFinal(final,0)

for t in ngrams.keys():
    for s in [t[:-1],t[1:]]:
        if not s in states:
            state = fst.AddState()
            states[s] = state
            fst.AddArc(initial,0,0,0,state)
            fst.AddArc(state,0,0,0,final)

print "# adding transitions"

def add_sep(frm,to,s):
    """Add a chain of arcs spelling the separator string `s` from `frm` to `to`.

    One arc is added per character, using the character's code point for
    both label arguments and 0 for the weight argument (argument order of
    AddArc presumably source, input label, output label, weight, target --
    confirm against the openfst binding).  A fresh state is created after
    every character except the last, which leads to `to`.  Each state a
    character arc leaves from also gets a label-0 arc from the global
    `initial` state and a label-0 arc to the global `final` state.

    An empty `s` adds nothing.
    """
    last = len(s)-1
    current = frm
    for pos,ch in enumerate(s):
        if pos==last:
            target = to
        else:
            target = fst.AddState()
        code = ord(ch)
        fst.AddArc(current,code,code,0,target)
        # can start and stop anywhere inside a separator
        fst.AddArc(initial,0,0,0,current)
        fst.AddArc(current,0,0,0,final)
        current = target

def add_word(frm,to,s):
    """Add a chain of arcs spelling the word `s` from state `frm` to state `to`.

    One arc is added per character, using the character's code point for
    both label arguments and 0.0 for the weight argument.  A fresh state is
    created after every character except the last, which leads to `to`.
    Unlike add_sep, no extra arcs to the start or accepting state are added
    for the intermediate states.

    An empty `s` adds nothing.
    """
    last = len(s)-1
    current = frm
    for pos,ch in enumerate(s):
        if pos==last:
            target = to
        else:
            target = fst.AddState()
        code = ord(ch)
        fst.AddArc(current,code,code,0.0,target)
        current = target

for ngram in ngrams:
    assert len(ngram)==options.n
    w = ngram[-1]
    s0 = tuple(ngram[:-1])
    s1 = tuple(ngram[1:])
    frm = states[s0]
    to = states[s1]
    if options.verbose: print s0,repr(w),s1
    if sep_re.search(w):
        add_sep(frm,to,w)
    else:
        add_word(frm,to,w)

if options.minimize:
    print "# minimizing"
    det = openfst.StdVectorFst()
    openfst.Determinize(fst,det)
    openfst.Minimize(det)
    fst = det

fst.SetInputSymbols(symtab)
fst.SetOutputSymbols(symtab)
fst.Write(options.output)
